package de.kp.wsclient.security;

import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import javax.crypto.Cipher;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;

import org.apache.xml.security.algorithms.JCEMapper;
import org.apache.xml.security.utils.Base64;
import org.w3c.dom.Attr;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NamedNodeMap;
import org.w3c.dom.Node;
import org.w3c.dom.Text;

import de.kp.wsclient.soap.SOAP11Constants;
import de.kp.wsclient.soap.SOAP12Constants;
import de.kp.wsclient.soap.SOAPConstants;
import de.kp.wsclient.util.UUIDGenerator;

public class SecUtil {

    /*
     * The default wsu:Id allocator is a simple "start at 1 and increment up"
     * thing that is very fast.
     */

    private static WsuIdAllocator idAllocator = new WsuIdAllocator() {
 
    	int i;
        private synchronized String next() {
            return Integer.toString(++i);
        }
        public String createId(String prefix, Object o) {
            if (prefix == null) {
                return next();
            }
            return prefix + next();
        }

        public String createSecureId(String prefix, Object o) {
            if (prefix == null) {
                return UUIDGenerator.getUUID();
            }
            return prefix + UUIDGenerator.getUUID();
        }
    };

    // A cached MessageDigest object
    private static MessageDigest digest = null;
    
    public static WsuIdAllocator getIdAllocator() {
        return idAllocator;
    }

    /**
     * @param startElement
     * @return
     */
    public static SOAPConstants getSOAPConstants(Element startElement) {

    	Document doc = startElement.getOwnerDocument();
        String ns = doc.getDocumentElement().getNamespaceURI();
        
        if (SecConstants.URI_SOAP12_ENV.equals(ns)) {
            return new SOAP12Constants();
        }
        
        return new SOAP11Constants();
    
    }

    /**
     * This method create a base64 test node.
     * 
     * @param doc
     * @param data
     * @return
     */
    public static Text createBase64EncodedTextNode(Document doc, byte data[]) {
        return doc.createTextNode(Base64.encode(data));
    }
   
    /**
     * @param startElement
     * @return
     */
    public static String getSOAPNamespace(Element startElement) {
        return getSOAPConstants(startElement).getEnvelopeURI();
    }

    /**
     * @param parent
     * @param child
     * @return
     */
    public static Element prependChildElement(Element parent, Element child) {
        
    	Node firstChild = parent.getFirstChild();
        if (firstChild == null) {
            return (Element)parent.appendChild(child);
        
        } else {
            return (Element)parent.insertBefore(child, firstChild);
        
        }
    }

    /**
     * Generate a (SHA1) digest of the input bytes. The MessageDigest 
     * instance that backs this method is cached for efficiency.  
     * 
     * @param inputBytes
     * @return
     * @throws Exception
     */
    public static synchronized byte[] generateDigest(byte[] inputBytes) throws Exception {
        
    	try {
        
    		if (digest == null) digest = MessageDigest.getInstance("SHA-1");
    		return digest.digest(inputBytes);
        
    	} catch (Exception e) {
            throw new Exception("[SecUtil] Error in generating digest");
            
        }
    }

    /**
     * Set a namespace/prefix on an element if it is not set already. 
     * First off, it searches for the element for the prefix associated 
     * with the specified namespace. If the prefix isn't null, then this 
     * is returned. Otherwise, it creates a new attribute using the 
     * namespace/prefix passed as parameters.
     * 
     * @param element
     * @param namespace
     * @param prefix
     * @return
     */
    public static String setNamespace(Element element, String namespace, String prefix) {
    
    	String pre = getPrefixNS(namespace, element);
        if (pre != null) {
            return pre;
        }
        
        element.setAttributeNS(SecConstants.XMLNS_NS, "xmlns:" + prefix, namespace);
        return prefix;
    
    }

    /**
     * 
     * @param uri
     * @param e
     * @return
     */
    public static String getPrefixNS(String uri, Node e) {

    	while (e != null && (e.getNodeType() == Element.ELEMENT_NODE)) {
        
    		NamedNodeMap attrs = e.getAttributes();
            for (int n = 0; n < attrs.getLength(); n++) {
            
            	Attr a = (Attr) attrs.item(n);
                String name = a.getName();
                
                if (name.startsWith("xmlns:") && a.getNodeValue().equals(uri)) {
                    return name.substring(6);
                }
            
            }
            
            e = e.getParentNode();
        
    	}
        return null;
    }
   
    /**
     * This method returns the first soap "Body" element.
     * 
     * @param doc
     * @return
     */
    public static Element findBodyElement(Document doc) {
    
    	//
        // Find the SOAP Envelope NS. Default to SOAP11 NS
        //
        Element docElement = doc.getDocumentElement();
        String ns = docElement.getNamespaceURI();
        return getDirectChildElement(docElement, SecConstants.ELEM_BODY, ns);

    }

    // Gets a direct child with specified localname and namespace. <p/>

    public static Element getDirectChildElement(Node parentNode, String localName, String namespace) {
        
    	if (parentNode == null) {
            return null;
        }
        
    	for (Node currentChild = parentNode.getFirstChild();  currentChild != null;  currentChild = currentChild.getNextSibling()) {
            
    		if (Node.ELEMENT_NODE == currentChild.getNodeType() && localName.equals(currentChild.getLocalName()) && namespace.equals(currentChild.getNamespaceURI())) {
                return (Element)currentChild;
            }
        }
    	
        return null;
    }

    /*
     * Find the DOM Element in the SOAP Envelope that is referenced by the 
     * WSEncryptionPart argument. 
     * 
     * The "Id" is used before the Element localname/namespace.
     */

    public static List<Element> findElements(SecEncPart part, CallbackLookup callbackLookup, Document doc) throws Exception {

    	// See if the DOM Element is stored in the WSEncryptionPart first
        if (part.getElement() != null) {
            return Collections.singletonList(part.getElement());
        }
        
        // Next try to find the Element via its wsu:Id
        String id = part.getId();
        if (id != null) {
            Element foundElement = callbackLookup.getElement(id, null, false);
            return Collections.singletonList(foundElement);
        }
        
        // Otherwise just lookup all elements with the localname/namespace
        return callbackLookup.getElements(part.getName(), part.getNamespace());
    }

    /*
     * Returns all elements that match name and namespace. 
     */
 
    public static List<Element> findElements(Node startNode, String name, String namespace) {

    	// Replace the formerly recursive implementation with a 
    	// depth-first-loop lookup

    	if (startNode == null) {
            return null;
        }
        
    	Node startParent = startNode.getParentNode();
        Node processedNode = null;

        List<Element> foundNodes = new ArrayList<Element>();
        while (startNode != null) {
 
        	// start node processing at this point
            if (startNode.getNodeType() == Node.ELEMENT_NODE && startNode.getLocalName().equals(name)) {
 
            	String ns = startNode.getNamespaceURI();
                if (ns != null && ns.equals(namespace)) {
                    foundNodes.add((Element)startNode);
                }

                if ((namespace == null || namespace.length() == 0) && (ns == null || ns.length() == 0)) {
                    foundNodes.add((Element)startNode);
                }
            }
 
            processedNode = startNode;
            startNode = startNode.getFirstChild();

            // no child, this node is done.
            if (startNode == null) {
                // close node processing, get sibling
                startNode = processedNode.getNextSibling();
            }
            
            // no more siblings, get parent, all children
            // of parent are processed.
            while (startNode == null) {
                processedNode = processedNode.getParentNode();
                if (processedNode == startParent) {
                    return foundNodes;
                }
                // close parent node processing (processed node now)
                startNode = processedNode.getNextSibling();
            }
        }
        
        return foundNodes;
    }

    /*
     * Returns the single element that contains an Id with value
     * uri and namespace. The Id can be either a wsu:Id or an Id
     * with no namespace. This is a replacement for a XPath Id 
     * lookup with the given namespace. 
     * 
     * It's somewhat faster than XPath, and we do not deal with 
     * prefixes, just with the real namespace URI
     * 
     * If checkMultipleElements is true and there are multiple 
     * elements, we log a warning and return null as this can 
     * be used to get around the signature checking.
     */

    public static Element findElementById(Node startNode, String value, boolean checkMultipleElements) {

    	//
        // Replace the formerly recursive implementation with a depth-first-loop lookup
        //
        Node startParent = startNode.getParentNode();
        Node processedNode = null;
        Element foundElement = null;
        String id = getIDFromReference(value);

        while (startNode != null) {
            // start node processing at this point
            if (startNode.getNodeType() == Node.ELEMENT_NODE) {
                Element se = (Element) startNode;
                // Try the wsu:Id first
                String attributeNS = se.getAttributeNS(SecConstants.WSU_NS, "Id");
                if ("".equals(attributeNS) || !id.equals(attributeNS)) {
                    attributeNS = se.getAttributeNS(null, "Id");
                }
                if (!"".equals(attributeNS) && id.equals(attributeNS)) {

                	if (!checkMultipleElements) {
                        return se;
                    
                	} else if (foundElement == null) {
                        foundElement = se; // Continue searching to find duplicates
                    
                	} else {
                        return null;
                    }
                }
            }

            processedNode = startNode;
            startNode = startNode.getFirstChild();

            // no child, this node is done.
            if (startNode == null) {
                // close node processing, get sibling
                startNode = processedNode.getNextSibling();
            }
            // no more siblings, get parent, all children
            // of parent are processed.
            while (startNode == null) {
                processedNode = processedNode.getParentNode();
                if (processedNode == startParent) {
                    return foundElement;
                }
                // close parent node processing (processed node now)
                startNode = processedNode.getNextSibling();
            }
        }
        return foundElement;
    }

    /*
     * Turn a reference (eg "#5") into an ID (eg "5").
     */
 
    public static String getIDFromReference(String ref) {
 
    	String id = ref.trim();
        if (id.length() == 0) {
            return null;
        }
        if (id.charAt(0) == '#') {
            id = id.substring(1);
        }
        return id;
    }

    /*
     * Translate the "cipherAlgo" URI to a JCE ID, and return a javax.crypto.Cipher instance
     * of this type. 
     */
    public static Cipher getCipherInstance(String cipherAlgo) throws Exception {

    	try {
            String keyAlgorithm =  JCEMapper.translateURItoJCEID(cipherAlgo);
            return Cipher.getInstance(keyAlgorithm);
        
    	} catch (NoSuchPaddingException ex) {
            throw new Exception("[SecUtil] Unsupported algorithm.");
        
    	} catch (NoSuchAlgorithmException ex) {

    		// Check to see if an RSA OAEP MGF-1 with SHA-1 algorithm was requested
            // Some JDKs don't support RSA/ECB/OAEPPadding

    		if (SecConstants.KEYTRANSPORT_RSAOEP.equals(cipherAlgo)) {
                try {
                    return Cipher.getInstance("RSA/ECB/OAEPWithSHA1AndMGF1Padding");
            
                } catch (Exception e) {
                    throw new Exception("[SecUtil] Unsupported algorithm.");
                }
            
            } else {
                throw new Exception("[SecUtil] Unsupported algorithm.");

            }
        
    	}
    }

    public static String getKeyAlgorithm(String symEncAlgo) {

    	String keyAlgorithm = JCEMapper.getJCEKeyAlgorithmFromURI(symEncAlgo);
        
        if (keyAlgorithm == null || "".equals(keyAlgorithm)) {
            keyAlgorithm = JCEMapper.translateURItoJCEID(symEncAlgo);
        }
        
        return keyAlgorithm;
        
    }
    
    /*
     * Convert the raw key bytes into a SecretKey object of type symEncAlgo.
     */
 
    public static SecretKey prepareSecretKey(String symEncAlgo, byte[] rawKey) {
    
    	// Do an additional check on the keysize required by the encryption algorithm
        int size = 0;
        try {
            size = JCEMapper.getKeyLengthFromURI(symEncAlgo) / 8;
        
        } catch (Exception e) {
            // ignore - some unknown (to JCEMapper) encryption algorithm
        }

        String keyAlgorithm = JCEMapper.getJCEKeyAlgorithmFromURI(symEncAlgo);
        SecretKeySpec keySpec;
        
        if (size > 0) {
            keySpec = new SecretKeySpec(rawKey, 0, ((rawKey.length > size) ? size : rawKey.length), keyAlgorithm);
        
        } else {
            keySpec = new SecretKeySpec(rawKey, keyAlgorithm);
        }
        
        return (SecretKey)keySpec;
    
    }

    public static boolean isContent(Node encBodyData) {

    	if (encBodyData != null) {
        
    		String typeStr = ((Element)encBodyData).getAttribute("Type");
            if (typeStr != null) {
                 return typeStr.equals(SecConstants.ENC_NS + "Content");
            }
        }
    	
        return true;
    
    }

}
